# Author:  Sebastian Ziaja
# Paper:   "More donors, more democracy"
# Date:    2017-04-15 
# Purpose: Personalize extract methods for texreg output.
# 

# dplyr is detached and re-attached last so that it ends up in front of the
# modelling packages on the search path. detach() errors when the package is
# not attached, so only call it when dplyr is actually on the search path.
if ("package:dplyr" %in% search()) {
    detach("package:dplyr")
}
require(texreg)
require(AER)
require(lfe)
require(clusterSEs)
require(dplyr)

# ivreg ############

# Custom texreg extract method for ivreg models (AER).
#
# Takes coefficients, standard errors and p-values from
# summary(model, diagnostics = TRUE) and appends the goodness-of-fit (gof)
# rows selected by the include.* switches, in the order they appear below.
#
# Arguments:
#   model               fitted ivreg model.
#   include.rsquared    report R^2.
#   include.adjrs       report adjusted R^2.
#   include.nobs        report nobs(model).
#   include.fstatistic  report the F statistic stored in the summary; the row
#                       is skipped when the summary has no such element.
#   include.rmse        report s$sigma; skipped when the summary has none.
#   include.fe          report country / period fixed effect indicators.
#   include.diag        report the leading "statistic" entries of the
#                       diagnostics table, one per instrument (one or two).
#   clusterse           unused; kept so that existing calls keep working
#                       (clustered standard errors come from lfe::felm).
#   ...                 forwarded to summary().
#
# Returns the texreg object created by createTexreg(). Stops with an error
# when more than two instruments are found and include.diag is set, because
# row labels exist for two instruments only.
extract.ivreg.cstm <- function (model, include.rsquared = TRUE, include.adjrs = FALSE, 
    include.nobs = TRUE, include.fstatistic = FALSE, include.rmse = FALSE, 
    include.fe = TRUE, include.diag = TRUE, clusterse = TRUE,
    ...) 
{
    s <- summary(model, diagnostics = TRUE, ...)
    est <- s$coefficients  # spelled out instead of partially matched `s$coef`

    gof <- numeric()
    gof.names <- character()
    gof.decimal <- logical()

    if (include.fe) {
        # Search the deparsed model terms for the fixed effect variables.
        # Base grepl() is used so that no string package has to be attached.
        terms.txt <- as.character(model$terms)
        # texreg gofs do not accept strings: 1111111 codes "present",
        # 9999999 codes "absent".
        cfe <- if (any(grepl("cname|gwno", terms.txt))) 1111111 else 9999999
        pfe <- if (any(grepl("period", terms.txt))) 1111111 else 9999999
        gof <- c(gof, cfe, pfe)
        gof.names <- c(gof.names, "Country fixed effects", "Period fixed effects")
        gof.decimal <- c(gof.decimal, FALSE, FALSE)
    }
    if (include.rsquared) {
        gof <- c(gof, s$r.squared)
        gof.names <- c(gof.names, "R$^2$")
        gof.decimal <- c(gof.decimal, TRUE)
    }
    if (include.adjrs) {
        gof <- c(gof, s$adj.r.squared)
        gof.names <- c(gof.names, "Adj. R$^2$")
        gof.decimal <- c(gof.decimal, TRUE)
    }
    if (include.nobs) {
        gof <- c(gof, nobs(model))
        gof.names <- c(gof.names, "Num. obs.")
        gof.decimal <- c(gof.decimal, FALSE)
    }
    # Guarded like the RMSE row: appending a label without a value would
    # leave gof and gof.names with different lengths.
    if (include.fstatistic && !is.null(s$fstatistic)) {
        gof <- c(gof, s$fstatistic[[1]])
        gof.names <- c(gof.names, "F statistic")
        gof.decimal <- c(gof.decimal, TRUE)
    }
    if (include.rmse && !is.null(s$sigma[[1]])) {
        gof <- c(gof, s$sigma[[1]])
        gof.names <- c(gof.names, "RMSE")
        gof.decimal <- c(gof.decimal, TRUE)
    }

    if (include.diag) {
        # Count the occurrences of "Z" in the instrument terms (presumably
        # the naming convention for the excluded instruments -- confirm).
        instr.txt <- as.character(model$terms$instruments)
        niv <- sum(vapply(gregexpr("Z", instr.txt, fixed = TRUE),
                          function(m) sum(m > 0), integer(1)))
        diag.labels <- c("KP F-statistic", "...for the second instrument")
        if (niv > length(diag.labels)) {
            stop("diagnostic rows are only labelled for one or two instruments",
                 call. = FALSE)
        }
        # seq_len() keeps all three vectors empty when niv is 0; `1:niv`
        # would have picked one value without a matching label.
        pick <- seq_len(niv)
        gof <- c(gof, s$diagnostics[, 3][pick])
        gof.names <- c(gof.names, diag.labels[pick])
        gof.decimal <- c(gof.decimal, rep(TRUE, niv))
    }

    createTexreg(coef.names = rownames(est), coef = est[, 1], se = est[, 2],
        pvalues = est[, 4], gof.names = gof.names, gof = gof,
        gof.decimal = gof.decimal)
}

# Register the custom extractor as the texreg `extract` method for ivreg
# objects, so that texreg's table functions use it for such models.
setMethod(
    f = "extract",
    signature = className("ivreg", "AER"),
    definition = extract.ivreg.cstm
)

# lm ##################

# Custom texreg extract method for lm models.
#
# Takes coefficients, standard errors and p-values from summary(model, ...)
# and appends the goodness-of-fit (gof) rows selected by the switches, in the
# order they appear below.
#
# Arguments:
#   model               fitted lm model.
#   include.rsquared    report R^2.
#   include.adjrs       report adjusted R^2.
#   include.nobs        report nobs(model).
#   include.fstatistic  report the F statistic; skipped when the summary has
#                       none.
#   include.rmse        report s$sigma; skipped when the summary has none.
#   include.fe          report country / period fixed effect indicators.
#   clusterse           must be NULL. Clustered standard errors are not
#                       implemented: an earlier attempt based on a
#                       df-adjusted vcovHC() did not reproduce the Stata
#                       results and was dropped.
#   controls            report whether additional controls (pop.l, gdpcap.l)
#                       are part of the model terms.
#   ...                 forwarded to summary().
#
# Returns the texreg object created by createTexreg().
extract.lm.cstm <- function (model, include.rsquared = TRUE, include.adjrs = FALSE, 
    include.nobs = TRUE, include.fstatistic = FALSE, include.rmse = FALSE, 
    include.fe = TRUE, clusterse = NULL, controls = FALSE,
    ...) 
{
    # Fail before doing any work when the unsupported option is requested.
    if (!is.null(clusterse)) {
        stop("clustered standard errors not implemented")
    }

    s <- summary(model, ...)
    est <- s$coefficients  # spelled out instead of partially matched `s$coef`

    gof <- numeric()
    gof.names <- character()
    gof.decimal <- logical()

    # Returns the numeric code for "some model term matches pattern".
    # texreg gofs do not accept strings: 1111111 codes "present",
    # 9999999 codes "absent". Base grepl() is used so that no string package
    # has to be attached.
    term.flag <- function(pattern) {
        if (any(grepl(pattern, as.character(model$terms)))) 1111111 else 9999999
    }

    if (include.fe) {
        gof <- c(gof, term.flag("cname|gwno"), term.flag("period"))
        gof.names <- c(gof.names, "Country fixed effects", "Period fixed effects")
        gof.decimal <- c(gof.decimal, FALSE, FALSE)
    }
    if (controls) {
        gof <- c(gof, term.flag("pop.l|gdpcap.l"))
        gof.names <- c(gof.names, "Additional controls")
        gof.decimal <- c(gof.decimal, FALSE)
    }
    if (include.rsquared) {
        gof <- c(gof, s$r.squared)
        gof.names <- c(gof.names, "R$^2$")
        gof.decimal <- c(gof.decimal, TRUE)
    }
    if (include.adjrs) {
        gof <- c(gof, s$adj.r.squared)
        gof.names <- c(gof.names, "Adj. R$^2$")
        gof.decimal <- c(gof.decimal, TRUE)
    }
    if (include.nobs) {
        gof <- c(gof, nobs(model))
        gof.names <- c(gof.names, "\\# of observations")
        gof.decimal <- c(gof.decimal, FALSE)
    }
    # Guarded like the RMSE row: appending a label without a value would
    # leave gof and gof.names with different lengths.
    if (include.fstatistic && !is.null(s$fstatistic)) {
        gof <- c(gof, s$fstatistic[[1]])
        gof.names <- c(gof.names, "F statistic")
        gof.decimal <- c(gof.decimal, TRUE)
    }
    if (include.rmse && !is.null(s$sigma[[1]])) {
        gof <- c(gof, s$sigma[[1]])
        gof.names <- c(gof.names, "RMSE")
        gof.decimal <- c(gof.decimal, TRUE)
    }

    createTexreg(coef.names = rownames(est), coef = est[, 1], se = est[, 2],
        pvalues = est[, 4], gof.names = gof.names, gof = gof,
        gof.decimal = gof.decimal)
}

# Register the custom extractor as the texreg `extract` method for lm objects.
# NOTE(review): the class is labelled with package "base" here -- confirm that
# this is the intended package label for "lm".
setMethod(
    f = "extract",
    signature = className("lm", "base"),
    definition = extract.lm.cstm
)

# felm ###############

# Custom texreg extract method for felm models (lfe).
#
# Takes coefficients, standard errors and p-values from summary(model) and
# appends the goodness-of-fit (gof) rows selected by the switches, in the
# order they appear below.
#
# Arguments:
#   model               fitted felm model.
#   include.nobs        report s$N.
#   include.rsquared    report s$r2.
#   include.adjrs       report s$r2adj.
#   include.fstatistic  report F statistics and p-values of the full and the
#                       projected model.
#   include.kp          for IV models, report one first-stage Wald F
#                       statistic per endogenous variable (one or two).
#   include.fe          report whether "cnamef" / "periodf" are among the
#                       fixed effects of the model.
#   include.nc          report the number of levels of the cluster variable
#                       "cnamef" (0 when the model is not clustered on it).
#   controls            report whether pop.l or gdpcap.l are among the
#                       coefficients.
#   ci.clust            unused; kept so that existing calls keep working.
#   ...                 unused.
#
# Returns the texreg object created by createTexreg(). Stops with an error
# when include.kp is set and the model has more than two endogenous
# variables, because row labels exist for two only.
extract.felm <- function(model, include.nobs = TRUE, include.rsquared = TRUE, 
    include.adjrs = FALSE, include.fstatistic = FALSE, include.kp = TRUE, 
    include.fe = TRUE, include.nc = TRUE, controls = TRUE, ci.clust = TRUE, ...) {

  s <- summary(model)
  est <- s$coefficients
  niv <- length(model$endovars)

  gof <- numeric()
  gof.names <- character()
  gof.decimal <- logical()

  if (include.fe) {
    # texreg gofs do not accept strings: 1111111 codes "present",
    # 9999999 codes "absent". %in% is FALSE for a model without fixed
    # effects, so no separate length check is needed.
    fe.names <- names(model$fe)
    cfe <- if ("cnamef" %in% fe.names) 1111111 else 9999999
    pfe <- if ("periodf" %in% fe.names) 1111111 else 9999999
    gof <- c(gof, cfe, pfe)
    gof.names <- c(gof.names, "Country fixed effects", "Period fixed effects")
    gof.decimal <- c(gof.decimal, FALSE, FALSE)
  }
  if (controls) {
    # Base grepl() is used so that no string package has to be attached.
    has.controls <- any(grepl("pop.l|gdpcap.l", rownames(model$coefficients)))
    gof <- c(gof, if (has.controls) 1111111 else 9999999)
    gof.names <- c(gof.names, "Additional controls")
    gof.decimal <- c(gof.decimal, FALSE)
  }
  if (include.nobs) {
    gof <- c(gof, s$N)
    gof.names <- c(gof.names, "\\# of observations")
    gof.decimal <- c(gof.decimal, FALSE)
  }
  if (include.nc) {
    gof <- c(gof, nlevels(model$clustervar$cnamef))
    gof.names <- c(gof.names, "\\# of countries")
    gof.decimal <- c(gof.decimal, FALSE)
  }
  if (include.rsquared) {
    gof <- c(gof, s$r2)
    gof.names <- c(gof.names, "R$^2$")
    gof.decimal <- c(gof.decimal, TRUE)
  }
  if (include.adjrs) {
    gof <- c(gof, s$r2adj)
    gof.names <- c(gof.names, "Adj.\ R$^2$")
    gof.decimal <- c(gof.decimal, TRUE)
  }
  if (include.fstatistic) {
    gof <- c(gof, s$F.fstat[1], s$F.fstat[4], 
        s$P.fstat[length(s$P.fstat) - 1], s$P.fstat[1])
    gof.names <- c(gof.names, "F statistic (full model)", 
        "F (full model): p-value", "F statistic (proj model)", 
        "F (proj model): p-value")
    gof.decimal <- c(gof.decimal, TRUE, TRUE, TRUE, TRUE)
  }

  # `&&` because this is a scalar condition for `if`.
  if (niv > 0 && include.kp) {
    kp.labels <- c("KP F-statistic", "...for the second instrument")
    if (niv > length(kp.labels)) {
      stop("first-stage rows are only labelled for one or two endogenous variables",
           call. = FALSE)
    }
    # The instruments are taken to be the last niv first-stage coefficients
    # (i.e. as many instruments as endogenous variables are assumed).
    ivvar <- tail(rownames(model$stage1$coefficients), niv)
    iv.formula <- as.formula(paste0("~", paste(ivvar, collapse = "|")))
    # One Wald test of the instruments per first-stage equation.
    ivF <- sapply(model$stage1$lhs, function(lh) 
      lfe::waldtest(model$stage1, iv.formula, lhs = lh))
    # Element 5 of each test result is reported (presumably the F statistic
    # -- confirm against the lfe::waldtest documentation).
    gof <- c(gof, ivF[5, ])
    gof.names <- c(gof.names, kp.labels[seq_len(niv)])
    gof.decimal <- c(gof.decimal, rep(TRUE, niv))
  }

  createTexreg(
    coef.names = rownames(est), 
    coef = est[, 1], 
    se = est[, 2], 
    pvalues = est[, 4], 
    gof.names = gof.names, 
    gof = gof, 
    gof.decimal = gof.decimal
  )
}

# Register the custom extractor as the texreg `extract` method for felm
# objects.
setMethod(
    f = "extract",
    signature = className("felm", "lfe"),
    definition = extract.felm
)

# lrm ##################

# Custom texreg extract method for lrm models (rms).
#
# Reports the model coefficients with standard errors and two-sided normal
# p-values. When `cluster` is given, the standard errors come from
# robcov() with the cluster variable taken from `data`; otherwise from the
# model's own variance matrix.
#
# Arguments:
#   model             fitted lrm model.
#   include.pseudors  report element 10 of model$stats as pseudo R^2.
#   include.lr        report element 3 of model$stats as L.R.
#   include.nobs      report element 1 of model$stats as number of
#                     observations.
#   cluster           name of the column in `data` that identifies clusters,
#                     or NULL for no clustering.
#   data              data set holding the cluster column. The default is the
#                     object `ddx`, which must exist when the default is
#                     used; it is only evaluated when `cluster` is not NULL.
#   include.nc        report the number of distinct clusters; skipped when
#                     `cluster` is NULL, as there is nothing to count.
#   ...               unused.
#
# Returns the texreg object created by createTexreg().
extract.lrm <- function (model, include.pseudors = TRUE, include.lr = TRUE, 
    include.nobs = TRUE, cluster = "dyad", data = ddx, include.nc = TRUE, ...) 
{
    # Typeset ">=" in the intercept names for LaTeX. The renaming is done on
    # a copy so that the model passed to robcov() stays untouched.
    coef <- model$coefficients
    names(coef) <- gsub(">=", " $\\\\geq$ ", names(coef))

    clustered <- !is.null(cluster)
    if (clustered) {
        cluster.id <- as.data.frame(data)[, cluster]
        se <- sqrt(diag(robcov(model, cluster = cluster.id)$var))
    } else {
        se <- sqrt(diag(model$var))
    }
    p <- pnorm(abs(coef / se), lower.tail = FALSE) * 2

    gof <- numeric()
    gof.names <- character()
    gof.decimal <- logical()
    if (include.nobs) {
        gof <- c(gof, model$stats[1])
        gof.names <- c(gof.names, "\\# of obs.")
        gof.decimal <- c(gof.decimal, FALSE)
    }
    if (include.nc && clustered) {
        gof <- c(gof, length(unique(cluster.id)))
        gof.names <- c(gof.names, "\\# of dyads")
        gof.decimal <- c(gof.decimal, FALSE)
    }
    if (include.pseudors) {
        gof <- c(gof, model$stats[10])
        gof.names <- c(gof.names, "Pseudo R$^2$")
        gof.decimal <- c(gof.decimal, TRUE)
    }
    if (include.lr) {
        gof <- c(gof, model$stats[3])
        gof.names <- c(gof.names, "L.R.")
        gof.decimal <- c(gof.decimal, TRUE)
    }

    createTexreg(coef.names = names(coef), coef = coef, 
        se = se, pvalues = p, gof.names = gof.names, gof = gof, 
        gof.decimal = gof.decimal)
}

# Register the custom extractor as the texreg `extract` method for lrm
# objects.
setMethod(
    f = "extract",
    signature = className("lrm", "rms"),
    definition = extract.lrm
)
